{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "import multiprocessing\n",
    "import yaml\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "from commonroad.common.file_reader import CommonRoadFileReader\n",
    "from commonroad.common.solution_writer import CommonRoadSolutionWriter, VehicleModel, VehicleType, CostFunction\n",
    "\n",
    "\n",
    "def parse_vehicle_model(model):\n",
    "    \"\"\"Translate a vehicle-model abbreviation into its VehicleModel member.\n",
    "\n",
    "    :param model: abbreviation of the model; 'PM', 'ST', 'KS' and 'MB' are accepted\n",
    "    :return: the VehicleModel member carrying the same name\n",
    "    :raises ValueError: if the abbreviation is not one of the accepted values\n",
    "    \"\"\"\n",
    "    # candidates are compared in a fixed order; the first equal name wins\n",
    "    for name in ('PM', 'ST', 'KS', 'MB'):\n",
    "        if model == name:\n",
    "            return getattr(VehicleModel, name)\n",
    "    raise ValueError('Selected vehicle model is not valid: {}.'.format(model))\n",
    "\n",
    "\n",
    "def parse_vehicle_type(type):\n",
    "    \"\"\"Translate a vehicle-type name into its VehicleType member and numeric ID.\n",
    "\n",
    "    :param type: name of the vehicle type; 'FORD_ESCORT', 'BMW_320i' and\n",
    "        'VW_VANAGON' are accepted\n",
    "    :return: tuple (VehicleType member, integer ID); the IDs are 1, 2 and 3\n",
    "        for the three accepted names, in the order listed above\n",
    "    :raises ValueError: if the name is not one of the accepted values\n",
    "    \"\"\"\n",
    "    # (name, numeric ID) pairs, compared in a fixed order\n",
    "    known_types = (('FORD_ESCORT', 1), ('BMW_320i', 2), ('VW_VANAGON', 3))\n",
    "    for name, type_id in known_types:\n",
    "        if type == name:\n",
    "            return getattr(VehicleType, name), type_id\n",
    "    raise ValueError('Selected vehicle type is not valid: {}.'.format(type))\n",
    "\n",
    "\n",
    "def parse_cost_function(cost):\n",
    "    \"\"\"Translate a cost-function abbreviation into its CostFunction member.\n",
    "\n",
    "    :param cost: abbreviation of the cost function; 'JB1', 'SA1', 'WX1',\n",
    "        'SM1', 'SM2' and 'SM3' are accepted\n",
    "    :return: the CostFunction member carrying the same name\n",
    "    :raises ValueError: if the abbreviation is not one of the accepted values\n",
    "    \"\"\"\n",
    "    # candidates are compared in a fixed order; the first equal name wins\n",
    "    for name in ('JB1', 'SA1', 'WX1', 'SM1', 'SM2', 'SM3'):\n",
    "        if cost == name:\n",
    "            return getattr(CostFunction, name)\n",
    "    raise ValueError('Selected cost function is not valid: {}.'.format(cost))\n",
    "\n",
    "\n",
    "def call_trajectory_planner(queue, function, scenario, planning_problem_set, vehicle_type_id):\n",
    "    \"\"\"Run the planner callable and hand its result back through ``queue``.\n",
    "\n",
    "    :param queue: object with a ``put`` method that receives the planner result\n",
    "    :param function: callable invoked as\n",
    "        ``function(scenario, planning_problem_set, vehicle_type_id)``\n",
    "    :param scenario: forwarded unchanged to ``function``\n",
    "    :param planning_problem_set: forwarded unchanged to ``function``\n",
    "    :param vehicle_type_id: forwarded unchanged to ``function``\n",
    "    \"\"\"\n",
    "    planner_result = function(scenario, planning_problem_set, vehicle_type_id)\n",
    "    queue.put(planner_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the batch-processing configuration.\n",
    "# yaml.safe_load only constructs plain Python objects; calling yaml.load\n",
    "# without an explicit Loader is deprecated and fails on PyYAML >= 6.\n",
    "with open('batch_processing_config.yaml', 'r') as stream:\n",
    "    try:\n",
    "        settings = yaml.safe_load(stream)\n",
    "    except yaml.YAMLError as exc:\n",
    "        print(exc)\n",
    "        # Re-raise: without a parsed config `settings` stays undefined and\n",
    "        # every later cell would only fail with a confusing NameError.\n",
    "        raise\n",
    "\n",
    "# Make the directory containing the planner module importable.\n",
    "# The configured path is interpreted relative to the current working\n",
    "# directory (a leading separator is ignored, as before). os.path.join\n",
    "# inserts the separator that plain string concatenation dropped for\n",
    "# values such as './planner/module.py'.\n",
    "planner_dir = os.path.dirname(settings['trajectory_planner_path'])\n",
    "sys.path.append(os.path.normpath(os.path.join(os.getcwd(), planner_dir.lstrip(os.sep))))\n",
    "\n",
    "# get planning wrapper function\n",
    "module = __import__(settings['trajectory_planner_module_name'])\n",
    "function = getattr(module, settings['trajectory_planner_function_name'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Started processing scenario USA_Lanker-1_6_T-1.xml\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Process trajectory_planner:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/edmond/Softwares/others/anaconda3/envs/cr37/lib/python3.7/multiprocessing/process.py\", line 297, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/home/edmond/Softwares/others/anaconda3/envs/cr37/lib/python3.7/multiprocessing/process.py\", line 99, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "TypeError: call_trajectory_planner() takes 4 positional arguments but 5 were given\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Planning finished.\n"
     ]
    }
   ],
   "source": [
    "# Iterate through the scenario files. The listing is sorted because\n",
    "# os.listdir returns entries in arbitrary order, which would make the\n",
    "# processing order differ between runs.\n",
    "for filename in sorted(os.listdir(settings['input_path'])):\n",
    "    if not filename.endswith('.xml'):\n",
    "        continue\n",
    "    fullname = os.path.join(settings['input_path'], filename)\n",
    "\n",
    "    print(\"Started processing scenario {}\".format(filename))\n",
    "    scenario, planning_problem_set = CommonRoadFileReader(fullname).open()\n",
    "\n",
    "    # Use the scenario-specific settings when the config has an entry for\n",
    "    # this benchmark ID, otherwise fall back to the 'default' entry.\n",
    "    if scenario.benchmark_id in settings:\n",
    "        scenario_settings = settings[scenario.benchmark_id]\n",
    "    else:\n",
    "        scenario_settings = settings['default']\n",
    "    vehicle_model = parse_vehicle_model(scenario_settings['vehicle_model'])\n",
    "    vehicle_type, vehicle_type_id = parse_vehicle_type(scenario_settings['vehicle_type'])\n",
    "    cost_function = parse_cost_function(scenario_settings['cost_function'])\n",
    "\n",
    "    queue = multiprocessing.Queue()\n",
    "    # Run the planner in its own process so it can be terminated on timeout;\n",
    "    # the result comes back through the queue.\n",
    "    p = multiprocessing.Process(target=call_trajectory_planner, name=\"trajectory_planner\",\n",
    "                                args=(queue, function, scenario, planning_problem_set, vehicle_type_id))\n",
    "    # start planning\n",
    "    p.start()\n",
    "\n",
    "    # wait till process ends or skip if timed out\n",
    "    # NOTE(review): join() runs before the queue is drained; the multiprocessing\n",
    "    # docs warn that a child holding a large queued result may not exit until\n",
    "    # it is read, in which case this is reported as a timeout.\n",
    "    p.join(timeout=settings['timeout'])\n",
    "\n",
    "    if p.is_alive():\n",
    "        print(\"Trajectory planner timeout.\")\n",
    "        p.terminate()\n",
    "        p.join()\n",
    "        solution_trajectories = {}\n",
    "    elif p.exitcode != 0:\n",
    "        # The planner process died (e.g. an uncaught exception) before it put\n",
    "        # a result on the queue; queue.get() would block forever here.\n",
    "        warnings.warn('Trajectory planner exited with code {} for scenario {}.'.format(\n",
    "            p.exitcode, filename))\n",
    "        solution_trajectories = {}\n",
    "    else:\n",
    "        print(\"Planning finished.\")\n",
    "        solution_trajectories = queue.get()\n",
    "\n",
    "    # create path for solutions\n",
    "    pathlib.Path(settings['output_path']).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    error = False\n",
    "    cr_solution_writer = CommonRoadSolutionWriter(settings['output_path'],\n",
    "                                                  scenario.benchmark_id,\n",
    "                                                  scenario.dt,\n",
    "                                                  vehicle_type,\n",
    "                                                  vehicle_model,\n",
    "                                                  cost_function)\n",
    "\n",
    "    # A solution file is only written when every planning problem of the\n",
    "    # scenario has a trajectory; the first missing one aborts the scenario.\n",
    "    for planning_problem_id, planning_problem in planning_problem_set.planning_problem_dict.items():\n",
    "        if planning_problem_id not in solution_trajectories:\n",
    "            warnings.warn('Solution for planning problem with ID={} is not provided for scenario {}. Solution skipped.'.format(\n",
    "                planning_problem_id, filename))\n",
    "            error = True\n",
    "            break\n",
    "        else:\n",
    "            cr_solution_writer.add_solution_trajectory(\n",
    "                solution_trajectories[planning_problem_id], planning_problem_id)\n",
    "    if not error:\n",
    "        cr_solution_writer.write_to_file(overwrite=settings['overwrite'])\n",
    "\n",
    "print(\"=========================================================\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (cr37)",
   "language": "python",
   "name": "cr37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
